#!/usr/bin/python
"""Calibrate the compass MinIMU, including roll.

Two steps:
- Turn boat around while level to get min/max magnetic field readings.
- Roll boat side to side at constant heading to calculate roll compensation.
"""
from __future__ import print_function

import csv
import curses
from datetime import datetime
import numpy as np
import os
from scipy.optimize import leastsq
import socket
import sys
import time
import rospy

my_dir = os.path.dirname(__file__)
robot_src_dir = os.path.abspath(os.path.join(my_dir, '../src/sailing_robot/src'))
sys.path.append(robot_src_dir)

from sailing_robot.imu_utils import ImuReader
from curses_imu import IMUDisplay, pitch_roll

# I2C wiring for the IMU board. The three values are passed straight to
# ImuReader below.
IMU_BUS = 1   # bus id given to ImuReader
LSM = 0x1d    # device I2C slave address
LGD = 0x6b    # device I2C slave address

# Fail fast if the ROS parameter server is unreachable: the calibration
# result is written there with rospy.set_param, so there is no point in
# collecting data when it cannot be stored.
try:
    rospy.get_param('/rosversion')
except socket.error:
    raise SystemExit("Can't connect to parameter server. Ensure roscore is running.")

# Module-level IMU handle used by the calibration routine.
imu = ImuReader(IMU_BUS, LSM, LGD)
imu.check_status()
imu.configure_for_reading()


def _wait_for_enter(stdscr):
    """Block until the user presses enter in the curses window."""
    key = ''
    while key != '\n':
        key = stdscr.getkey()


def _record_samples(stdscr, display, filename, n_samples):
    """Sample the IMU up to n_samples times, logging every reading to a CSV.

    One reading is taken roughly every 0.1 s. Each reading is written to
    ``filename`` (columns mag_x, mag_y, mag_z, acc_x, acc_y, acc_z), shown on
    ``display``, and counted on the bottom line of ``stdscr``. Ctrl-C stops
    the sampling early and keeps what has been collected so far.

    Returns a list of (mag_x, mag_y, mag_z, pitch, roll) tuples, one per
    reading that was taken.
    """
    samples = []
    # The context manager closes (and flushes) the log even if an IMU read
    # or a screen update raises part-way through.
    with open(filename, 'w') as f:
        cw = csv.writer(f)
        cw.writerow(['mag_x', 'mag_y', 'mag_z', 'acc_x', 'acc_y', 'acc_z'])

        for i in range(n_samples):
            try:
                time.sleep(0.1)

                magx, magy, magz = imu.read_mag_field()
                accx, accy, accz = imu.read_acceleration()
                pitch, roll = pitch_roll(accx, accy, accz)
                cw.writerow([magx, magy, magz, accx, accy, accz])
                samples.append((magx, magy, magz, pitch, roll))

                display.update_mag(magx, magy, magz)
                display.update_acc(accx, accy, accz)
                display.update_pitch_roll(pitch, roll)
                # i + 1 so the counter finishes on n_samples/n_samples.
                stdscr.addstr(curses.LINES-1, 0, "{:>4}".format(i + 1))
                stdscr.refresh()
            except KeyboardInterrupt:
                break

    return samples


def interactive_calibration(stdscr):
    """Run the two-step compass calibration in a curses window.

    Step 1 records magnetometer readings while the user turns the boat; the
    X and Y offsets are the midpoints of the min/max readings and the X and Y
    scales are the max - min ranges.

    Step 2 records readings while the user rolls the boat; a least-squares
    fit then finds the Z offset and scale that make the roll-compensated Y
    field match the Y field measured while the boat is within 3 degrees of
    level.

    Side effects: writes the timestamp to ``latest_calibration_time``, writes
    ``calibration_level_<ts>.csv`` and ``calibration_roll_<ts>.csv`` in the
    current directory, and stores the result in the ROS parameter
    ``/calibration/compass``.

    Returns ``((offset_X, offset_Y, offset_Z), (range_X, range_Y, range_Z))``.

    Raises ValueError if a step yields unusable data (too few samples, no
    variation in the X or Y field, or no near-level samples in step 2), and
    RuntimeError if the least-squares fit does not converge.
    """
    display = IMUDisplay(stdscr)
    n_samples = 300

    # ---- Step 1: turn the boat around while level -------------------------
    stdscr.addstr(curses.LINES-2, 0, '1. The minute waltz')
    stdscr.addstr(curses.LINES-1, 0, 'Press enter to start')
    _wait_for_enter(stdscr)
    stdscr.addstr(curses.LINES-1, 0,
                  '   0/{}                 '.format(n_samples))

    ts = datetime.utcnow().strftime('%Y-%m-%dT%H.%M.%S')
    with open('latest_calibration_time', 'w') as f:
        f.write(ts)

    level_samples = _record_samples(
        stdscr, display, 'calibration_level_{}.csv'.format(ts), n_samples)
    if not level_samples:
        raise ValueError('No samples were recorded in the level step.')

    data_X = [s[0] for s in level_samples]
    data_Y = [s[1] for s in level_samples]
    minx, maxx = min(data_X), max(data_X)
    miny, maxy = min(data_Y), max(data_Y)

    # 2.0 keeps this a true division under Python 2 as well, in case the IMU
    # reader reports integer values.
    offset_X = (maxx + minx) / 2.0
    offset_Y = (maxy + miny) / 2.0
    range_X = maxx - minx
    range_Y = maxy - miny
    if range_X == 0 or range_Y == 0:
        raise ValueError('Magnetic field did not vary in X or Y during the '
                         'level step; cannot compute a scale.')

    # ---- Step 2: roll the boat side to side -------------------------------
    stdscr.addstr(curses.LINES-2, 0, '2. Rock me baby like a wagon wheel')
    stdscr.addstr(curses.LINES-1, 0, 'Press enter to start')
    _wait_for_enter(stdscr)
    stdscr.addstr(curses.LINES-1, 0,
                  '   0/{}                 '.format(n_samples))

    roll_samples = _record_samples(
        stdscr, display, 'calibration_roll_{}.csv'.format(ts), n_samples)
    # The fit below has two unknowns, so it needs at least two residuals.
    if len(roll_samples) < 2:
        raise ValueError('Too few samples were recorded in the roll step.')

    # Float arrays so the normalisation below never uses integer division.
    mag_x, mag_y, raw_z, pitch_d, roll_d = (
        np.array(column, dtype=float) for column in zip(*roll_samples))

    adjusted_x = (mag_x - offset_X) / range_X
    adjusted_y = (mag_y - offset_Y) / range_Y

    roll_r = np.radians(roll_d)
    pitch_r = np.radians(pitch_d)
    sin_roll, cos_roll = np.sin(roll_r), np.cos(roll_r)
    sin_pitch, cos_pitch = np.sin(pitch_r), np.cos(pitch_r)

    # Take our correct y field as the points where we're within 3 degrees of
    # level.
    near_level = (-3 < roll_d) & (roll_d < +3)
    if not near_level.any():
        raise ValueError('No roll-step samples within 3 degrees of level; '
                         'cannot determine the reference Y field.')
    y_flat = adjusted_y[near_level].mean()

    def mag_y_comp_residuals(p):
        """Residuals of roll-compensated Y for p = (Z offset, Z scale)."""
        adjusted_z = (raw_z - p[0]) / p[1]
        mag_y_comp = ((adjusted_x * sin_roll * sin_pitch)
                      + (adjusted_y * cos_roll)
                      - (adjusted_z * sin_roll * cos_pitch))
        return mag_y_comp - y_flat

    res, ier = leastsq(mag_y_comp_residuals, (1, 1))
    # According to the docs, an integer code between 1 and 4 (inclusive)
    # indicates success. Raise explicitly rather than assert, so the check
    # still runs under `python -O`.
    if not 1 <= ier <= 4:
        raise RuntimeError(
            'Least-squares fit for Z calibration failed (ier={}).'.format(ier))

    # Plain Python floats, so they can be stored on the parameter server.
    offset_Z = float(res[0])
    range_Z = float(res[1])

    rospy.set_param('/calibration/compass', {'XOFFSET': offset_X,
                                              'YOFFSET': offset_Y,
                                              'ZOFFSET': offset_Z,
                                              'XSCALE': range_X,
                                              'YSCALE': range_Y,
                                              'ZSCALE': range_Z,
                                              })

    return (offset_X, offset_Y, offset_Z), (range_X, range_Y, range_Z)


# Run the calibration under curses.wrapper, then report the results on the
# normal terminal once the curses session has ended.
offset_xyz, range_xyz = curses.wrapper(interactive_calibration)

for axis, offset in zip('XYZ', offset_xyz):
    print(axis + "OFFSET = " + str(offset))
for axis, scale in zip('XYZ', range_xyz):
    print(axis + "SCALE = " + str(scale))
